# -*- coding: utf-8 -*-
"""
Created on Sun Sep 29 18:27:30 2024
"""

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Raw training-time samples in seconds (the unit is taken from the
# 'Training Time (seconds)' column they are stored under below).
# B1..B10 are labelled 'Baseline' and P1..P10 'Proposed method'.
# NOTE(review): presumably one list per run with one entry per epoch -- confirm.
# The B lists are not all the same length (8 to 10 entries each).
B1_train_time = [4032,4057,4041,4076,4082,4030,4018,4066,4015,4007]
B2_train_time = [4430,4383,4405,4457,4404,4456,4434,4551,5272]
B3_train_time = [3311,3221,3223,3214,3196,3217,3204,3236,3233,3196]
B4_train_time = [4717,4701,4674,4739,4692,4746,4753,4745,4693]
B5_train_time = [4783,4780,4742,4758,4728,4739,4726,4766,4717]
B6_train_time = [4046,4063,4036,4039,4049,4057,4030,4049,4032,3990]
B7_train_time = [4558,4514,4552,4495,4488,4765,5725,5656]
B8_train_time = [3253,3862,3656,3193,3164,3221,3136,3151,3160,3147]
B9_train_time = [5262,5285,5210,5271,5275,5310,5269,5252]
B10_train_time = [4782,4732,4729,4669,4742,4685,4722,4732,4717]
P1_train_time = [1266,1287,1281,1302,1312,1331,1380,1438,1423,1429]
P2_train_time = [1236,1240,1256,1275,1320,1357,1371,1379,1383,1381]
P3_train_time = [1272,1257,1282,1278,1270,1286,1299,1295,1297,1286]
P4_train_time = [1079,1081,1074,1087,1081,1107,1142,1136,1133,1083]
P5_train_time = [1320,1349,1346,1362,1351,1364,1394,1414,1401,1408]
P6_train_time = [1340,1366,1361,1362,1367,1387,1374,1418,1361,1371]
P7_train_time = [1296,1300,1310,1309,1316,1316,1384,1381,1385,1399]
P8_train_time = [1400,1407,1701,1727,1427,1446,1460,1448,1478,1455]
P9_train_time = [1229,1254,1254,1273,1341,1355,1367,1369,1380,1372]
P10_train_time = [1790,1791,1741,1661,1690,1709,1438,1329,1323,1328]

# Flatten the per-list samples into one list per method, keeping list order
# and sample order, and build one method label per sample.
B_train_times = [B1_train_time,B2_train_time,B3_train_time,B4_train_time,B5_train_time,B6_train_time,B7_train_time,B8_train_time,B9_train_time,B10_train_time]
B_train_time_flat = [sample for times in B_train_times for sample in times]
B_category = ['Baseline'] * len(B_train_time_flat)

P_train_times = [P1_train_time,P2_train_time,P3_train_time,P4_train_time,P5_train_time,P6_train_time,P7_train_time,P8_train_time,P9_train_time,P10_train_time]
P_train_time_flat = [sample for times in P_train_times for sample in times]
P_category = ['Proposed method'] * len(P_train_time_flat)

# Long-form table for plotting: one row per sample, Baseline rows first.
df = pd.DataFrame({
    'Method': B_category + P_category,
    'Training Time (seconds)': B_train_time_flat + P_train_time_flat,
})

# Wrap each flat sample list in a single-column DataFrame (column label 0).
B_train_time_flat_df = pd.DataFrame(B_train_time_flat)
P_train_time_flat_df = pd.DataFrame(P_train_time_flat)

# Statistics to report, in the order they appear as output columns.
_STAT_NAMES = ('mean', 'std', 'min', 'max', 'median')


def _summarize(frame):
    """Return a table with one row per column of *frame* and one column per statistic."""
    return pd.DataFrame({stat: getattr(frame, stat)() for stat in _STAT_NAMES})


B_train_time_stats = _summarize(B_train_time_flat_df)
P_train_time_stats = _summarize(P_train_time_flat_df)

print(B_train_time_stats)
print(P_train_time_stats)

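# Reference: hex codes of seaborn's "colorblind" palette, i.e. the output of
# sns.color_palette("colorblind").as_hex(). The first two entries are the bar
# colours used in the bar plot below.
# ['#0173b2',
#  '#de8f05',
#  '#029e73',
#  '#d55e00',
#  '#cc78bc',
#  '#ca9161',
#  '#fbafe4',
#  '#949494',
#  '#ece133',
#  '#56b4e9']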

#%% Bar plot
# Draws one bar per method (bar height = mean of that method's samples, no
# error bars), overlays the individual samples as a strip plot, then saves the
# figure to "bar plot.png" and shows it.
plt.rcParams['axes.titley'] = 1.05    # y is in axes-relative coordinates.
# plt.rcParams['figure.figsize'] -> [6.4, 4.8]; overridden by figsize= below.
# print(sns.color_palette().as_hex())
# blue: #1f77b4
# orange: #ff7f0e
fig3, ax3 = plt.subplots(figsize=(7, 5.56))
# NOTE(review): seaborn >= 0.13 emits a FutureWarning when `palette` is passed
# without `hue`; the suggested replacement is hue="Method", legend=False, but
# that needs seaborn 0.13+, so the calls are left as they were.
ax = sns.barplot(data=df, x="Method", y="Training Time (seconds)",saturation=0.5,width=0.5,palette=['#0173b2', '#de8f05'], ax=ax3, errorbar=None)
# Label every container, not just the first: depending on the seaborn version
# the bars may be split across several containers (one per bar).
for bars in ax.containers:
    ax.bar_label(bars, fontsize=20)
sns.stripplot(data=df, x="Method", y="Training Time (seconds)",palette=['#ffb482', '#a1c9f4'], size=1, alpha=0.5, ax=ax3)
# Tick-label font size. `Tick.label` was removed in Matplotlib 3.8, so this is
# set through tick_params instead of iterating over the tick objects.
ax3.tick_params(axis='both', which='major', labelsize=15)
ax.set_title('Comparison of Training\n Time per epoch', fontsize=30)
ax3.set_ylabel('Average training time \nper epoch (seconds)',fontsize=20, weight='bold')
ax3.set_xlabel('Methods',fontsize=20, weight='bold')
ax3.grid()
fig3.savefig("bar plot.png", bbox_inches='tight')
plt.show()

